{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# PaddleOCR在DJL 上的實現\n",
    "在這個教程裡，我們會展示利用 PaddleOCR 下載預訓練好文字處理模型並對指定的照片進行文學文字檢測 (OCR)。這個教程總共會分成三個部分:\n",
    "\n",
    "- 文字區塊檢測: 從圖片檢測出文字區塊\n",
    "- 文字角度檢測: 確認文字是否需要旋轉\n",
    "- 文字識別: 確認區塊內的文字\n",
    "\n",
    "## 導入相關環境依賴及子類別\n",
    "在這個例子中的前處理飛槳深度學習引擎需要搭配DJL混合模式進行深度學習推理，原因是引擎本身沒有包含ND數組操作，因此需要藉用其他引擎的數組操作能力來完成。這邊我們導入Pytorch來做協同的前處理工作:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// %mavenRepo snapshots https://oss.sonatype.org/content/repositories/snapshots/\n",
    "\n",
    "%maven ai.djl:api:0.12.0\n",
    "%maven ai.djl.paddlepaddle:paddlepaddle-model-zoo:0.12.0\n",
    "%maven ai.djl.paddlepaddle:paddlepaddle-native-auto:2.0.2\n",
    "%maven org.slf4j:slf4j-api:1.7.26\n",
    "%maven org.slf4j:slf4j-simple:1.7.26\n",
    "\n",
    "// second engine to do preprocessing and postprocessing\n",
    "%maven ai.djl.pytorch:pytorch-engine:0.12.0\n",
    "%maven ai.djl.pytorch:pytorch-native-auto:1.8.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import ai.djl.*;\n",
    "import ai.djl.inference.Predictor;\n",
    "import ai.djl.modality.Classifications;\n",
    "import ai.djl.modality.cv.Image;\n",
    "import ai.djl.modality.cv.ImageFactory;\n",
    "import ai.djl.modality.cv.output.*;\n",
    "import ai.djl.modality.cv.util.NDImageUtils;\n",
    "import ai.djl.ndarray.*;\n",
    "import ai.djl.ndarray.types.DataType;\n",
    "import ai.djl.ndarray.types.Shape;\n",
    "import ai.djl.repository.zoo.*;\n",
    "import ai.djl.paddlepaddle.zoo.cv.objectdetection.PpWordDetectionTranslator;\n",
    "import ai.djl.paddlepaddle.zoo.cv.imageclassification.PpWordRotateTranslator;\n",
    "import ai.djl.paddlepaddle.zoo.cv.wordrecognition.PpWordRecognitionTranslator;\n",
    "import ai.djl.translate.*;\n",
    "import java.util.concurrent.ConcurrentHashMap;"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 圖片讀取\n",
    "首先讓我們載入這次教程會用到的機票範例圖片:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "String url = \"https://resources.djl.ai/images/flight_ticket.jpg\";\n",
    "Image img = ImageFactory.getInstance().fromUrl(url);\n",
    "img.getWrappedImage();"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 文字區塊檢測\n",
    "我們首先從 [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.0/doc/doc_en/inference_en.md#convert-detection-model-to-inference-model) 開發套件中讀取文字檢測的模型，之後我們可以生成一個DJL `Predictor` 並將其命名為 `detector`.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "var criteria1 = Criteria.builder()\n",
    "                .optEngine(\"PaddlePaddle\")\n",
    "                .setTypes(Image.class, DetectedObjects.class)\n",
    "                .optModelUrls(\"https://resources.djl.ai/test-models/paddleOCR/mobile/det_db.zip\")\n",
    "                .optTranslator(new PpWordDetectionTranslator(new ConcurrentHashMap<String, String>()))\n",
    "                .build();\n",
    "var detectionModel = criteria1.loadModel();\n",
    "var detector = detectionModel.newPredictor();"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "接著我們檢測出圖片中的文字區塊，這個模型的原始輸出是含有標註所有文字區域的圖算法(Bitmap)，我們可以利用`PpWordDetectionTranslator` 函式將圖算法的輸出轉成長方形的方框來裁剪圖片"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "var detectedObj = detector.predict(img);\n",
    "Image newImage = img.duplicate(Image.Type.TYPE_INT_ARGB);\n",
    "newImage.drawBoundingBoxes(detectedObj);\n",
    "newImage.getWrappedImage();"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "如上所示，所標註的文字區塊都非常窄，且沒有包住所有完整的文字區塊。讓我們嘗試使用`extendRect`函式來擴展文字框的長寬到需要的大小, 再利用 `getSubImage` 裁剪並擷取出文子區塊。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Image getSubImage(Image img, BoundingBox box) {\n",
    "    Rectangle rect = box.getBounds();\n",
    "    double[] extended = extendRect(rect.getX(), rect.getY(), rect.getWidth(), rect.getHeight());\n",
    "    int width = img.getWidth();\n",
    "    int height = img.getHeight();\n",
    "    int[] recovered = {\n",
    "        (int) (extended[0] * width),\n",
    "        (int) (extended[1] * height),\n",
    "        (int) (extended[2] * width),\n",
    "        (int) (extended[3] * height)\n",
    "    };\n",
    "    return img.getSubimage(recovered[0], recovered[1], recovered[2], recovered[3]);\n",
    "}\n",
    "\n",
    "double[] extendRect(double xmin, double ymin, double width, double height) {\n",
    "    double centerx = xmin + width / 2;\n",
    "    double centery = ymin + height / 2;\n",
    "    if (width > height) {\n",
    "        width += height * 2.0;\n",
    "        height *= 3.0;\n",
    "    } else {\n",
    "        height += width * 2.0;\n",
    "        width *= 3.0;\n",
    "    }\n",
    "    double newX = centerx - width / 2 < 0 ? 0 : centerx - width / 2;\n",
    "    double newY = centery - height / 2 < 0 ? 0 : centery - height / 2;\n",
    "    double newWidth = newX + width > 1 ? 1 - newX : width;\n",
    "    double newHeight = newY + height > 1 ? 1 - newY : height;\n",
    "    return new double[] {newX, newY, newWidth, newHeight};\n",
    "}"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "讓我們輸出其中一個文字區塊"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "List<DetectedObjects.DetectedObject> boxes = detectedObj.items();\n",
    "var sample = getSubImage(img, boxes.get(5).getBoundingBox());\n",
    "sample.getWrappedImage();"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 文字角度檢測\n",
    "我們從 [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.0/doc/doc_en/inference_en.md#convert-angle-classification-model-to-inference-model) 輸出這個模型並確認圖片及文字是否需要旋轉。以下的代碼會讀入這個模型並生成a `rotateClassifier` 子類別"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "var criteria2 = Criteria.builder()\n",
    "                .optEngine(\"PaddlePaddle\")\n",
    "                .setTypes(Image.class, Classifications.class)\n",
    "                .optModelUrls(\"https://resources.djl.ai/test-models/paddleOCR/mobile/cls.zip\")\n",
    "                .optTranslator(new PpWordRotateTranslator())\n",
    "                .build();\n",
    "var rotateModel = criteria2.loadModel();\n",
    "var rotateClassifier = rotateModel.newPredictor();"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 文字識別\n",
    "\n",
    "我們從 [PaddleOCR](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.0/doc/doc_en/inference_en.md#convert-recognition-model-to-inference-model) 輸出這個模型並識別圖片中的文字, 我們一樣仿造上述的步驟讀取這個模型\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "var criteria3 = Criteria.builder()\n",
    "                .optEngine(\"PaddlePaddle\")\n",
    "                .setTypes(Image.class, String.class)\n",
    "                .optModelUrls(\"https://resources.djl.ai/test-models/paddleOCR/mobile/rec_crnn.zip\")\n",
    "                .optTranslator(new PpWordRecognitionTranslator())\n",
    "                .build();\n",
    "var recognitionModel = criteria3.loadModel();\n",
    "var recognizer = recognitionModel.newPredictor();"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "接著我們可以試著套用這兩個模型在先前剪裁好的文字區塊上"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "System.out.println(rotateClassifier.predict(sample));\n",
    "recognizer.predict(sample);"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "最後我們把這些模型串連在一起並套用在整張圖片上看看結果會如何。DJL提供了豐富的影像工具包讓你可以從圖片中擷取出文字並且完美呈現"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Image rotateImg(Image image) {\n",
    "    try (NDManager manager = NDManager.newBaseManager()) {\n",
    "        NDArray rotated = NDImageUtils.rotate90(image.toNDArray(manager), 1);\n",
    "        return ImageFactory.getInstance().fromNDArray(rotated);\n",
    "    }\n",
    "}\n",
    "\n",
    "List<String> names = new ArrayList<>();\n",
    "List<Double> prob = new ArrayList<>();\n",
    "List<BoundingBox> rect = new ArrayList<>();\n",
    "\n",
    "for (int i = 0; i < boxes.size(); i++) {\n",
    "    Image subImg = getSubImage(img, boxes.get(i).getBoundingBox());\n",
    "    if (subImg.getHeight() * 1.0 / subImg.getWidth() > 1.5) {\n",
    "        subImg = rotateImg(subImg);\n",
    "    }\n",
    "    Classifications.Classification result = rotateClassifier.predict(subImg).best();\n",
    "    if (\"Rotate\".equals(result.getClassName()) && result.getProbability() > 0.8) {\n",
    "        subImg = rotateImg(subImg);\n",
    "    }\n",
    "    String name = recognizer.predict(subImg);\n",
    "    names.add(name);\n",
    "    prob.add(-1.0);\n",
    "    rect.add(boxes.get(i).getBoundingBox());\n",
    "}\n",
    "newImage.drawBoundingBoxes(new DetectedObjects(names, prob, rect));\n",
    "newImage.getWrappedImage();"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Java",
   "language": "java",
   "name": "java"
  },
  "language_info": {
   "codemirror_mode": "java",
   "file_extension": ".jshell",
   "mimetype": "text/x-java-source",
   "name": "Java",
   "pygments_lexer": "java",
   "version": "11.0.5+10-LTS"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
